import torch
import torch.utils.data
from torch.nn import functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
from pytorch_lightning import loggers as pl_loggers
import os
import json, csv
import time
from tqdm.auto import tqdm
from einops import rearrange, reduce
import numpy as np
import trimesh
import warnings
import open3d as o3d

from models import *
from utils import mesh, evaluate
from utils.reconstruct import *
from utils.test_reconstruct import *
from diff_utils.helpers import *

from dataloader.pc_loader import PCloader
import pyvista as pv
from dataloader.pc_loader import PCloader
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.loss import chamfer_distance
from pytorch3d.structures import Meshes
import os
import torch

@torch.no_grad()
def test_modulations():
    """Reconstruct every split shape, log its chamfer distance and save its latent.

    For each point cloud listed in ``specs["TrainSplit"]`` this:

    1. writes the input points to ``<recon_dir>/<cls>/<mesh>/point_cloud.ply``
       (ASCII PLY, vertex positions only);
    2. encodes the points into plane features with
       ``model.sdf_model.pointnet``, passes them together with ``atc`` through
       ``model.vae_model.generate`` and extracts a mesh from the result at
       ``<recon_dir>/<cls>/<mesh>/reconstruct``;
    3. calls ``evaluate.main`` with ``<recon_dir>/cd.csv`` as the log file;
    4. if ``filter_threshold(mesh_filename, point_cloud, 0.005)`` accepts the
       reconstruction, saves the VAE latent of the plane features to
       ``<latent_dir>/<cls>/<mesh>/latent.txt``.

    Exceptions raised in steps 3 and 4 are printed and the loop continues
    with the next shape.

    Relies on the module-level globals ``specs``, ``args``, ``recon_dir`` and
    ``latent_dir`` that the ``__main__`` block sets up.
    """
    # Read the split with a context manager so the file handle is closed.
    with open(specs["TrainSplit"], "r") as split_file:
        test_split = [line.strip() for line in split_file]
    test_dataset = PCloader(specs["DataSource"], test_split, pc_size=specs.get("PCsize", 1024), return_filename=True)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, num_workers=0)

    ckpt = "{}.ckpt".format(args.resume) if args.resume == 'last' else "epoch={}.ckpt".format(args.resume)
    resume = os.path.join(args.exp_dir, ckpt)
    model = CombinedModel.load_from_checkpoint(resume, specs=specs).cuda().eval()
    cd_file = os.path.join(recon_dir, "cd.csv")

    with tqdm(test_dataloader) as pbar:
        for idx, data in enumerate(pbar):
            pbar.set_description("Files evaluated: {}/{}".format(idx, len(test_dataloader)))

            # The batch must provide the point cloud, the ``atc`` value that
            # ``model.vae_model.generate`` needs, and the source filename.
            # ``test_generation`` unpacks the same loader output as a
            # (point_cloud, atc, filename) sequence; a dict batch is accepted too.
            # NOTE(review): the "atc" dict key name is an assumption — confirm
            # against PCloader.
            if isinstance(data, dict):
                point_cloud = data['point_cloud']
                atc = data['atc']
                filename = data['file_name']
            else:
                point_cloud, atc, filename = data

            # batch_size=1, so the collated filename list has a single entry.
            filename = filename[0]

            # Class / mesh names come from fixed path positions; assumes the
            # loader's directory layout — TODO confirm.
            cls_name = filename.split("/")[-4]
            mesh_name = filename.split("/")[-1].split('.')[0]

            outdir = os.path.join(recon_dir, "{}/{}".format(cls_name, mesh_name))
            os.makedirs(outdir, exist_ok=True)
            mesh_filename = os.path.join(outdir, "reconstruct")
            ply_filename = os.path.join(outdir, 'point_cloud.ply')

            # Dump the input point cloud as a minimal ASCII PLY for inspection.
            points = point_cloud.squeeze(0).numpy()
            with open(ply_filename, "w") as f:
                f.write(
                    "ply\nformat ascii 1.0\nelement vertex {}\nproperty float x\nproperty float y\nproperty float z\nend_header\n".format(
                        points.shape[0]))
                np.savetxt(f, points, fmt="%.6f %.6f %.6f")

            # get_plane_features returns a sequence of tensors that are
            # concatenated along dim 1 before being handed to the VAE.
            plane_features = model.sdf_model.pointnet.get_plane_features(point_cloud.cuda())
            plane_features = torch.cat(plane_features, dim=1)
            recon = model.vae_model.generate(plane_features, atc.cuda())
            print("mesh filename: ", mesh_filename)

            mesh.create_mesh(model.sdf_model, recon, mesh_filename, N=256, max_batch=2 ** 16, from_plane_features=True)

            mesh_log_name = cls_name + "/" + mesh_name
            try:
                evaluate.main(point_cloud, mesh_filename, cd_file, mesh_log_name)
            except Exception as e:
                # Best effort: one failed evaluation must not stop the sweep.
                print(e)

            try:
                # Only save the latent when the reconstruction passes the filter.
                if not filter_threshold(mesh_filename, point_cloud, 0.005):
                    continue
                latent_outdir = os.path.join(latent_dir, "{}/{}".format(cls_name, mesh_name))
                os.makedirs(latent_outdir, exist_ok=True)
                features = model.sdf_model.pointnet.get_plane_features(point_cloud.cuda())
                features = torch.cat(features, dim=1)
                latent = model.vae_model.get_latent(features)
                np.savetxt(os.path.join(latent_outdir, "latent.txt"), latent.cpu().numpy())
            except Exception as e:
                print(e)


@torch.no_grad()
def test_generation():
    """Sample shapes from the diffusion model and extract a mesh for each sample.

    Model loading:
      * ``args.resume == 'finetune'``: build the combined model from
        ``specs["modulation_ckpt_path"]`` (non-strict), then load the
        diffusion submodule from ``specs["diffusion_ckpt_path"]``.
      * otherwise: load ``last.ckpt`` / ``epoch=<resume>.ckpt`` from
        ``args.exp_dir``.

    Sampling:
      * If ``specs["diffusion_model_specs"]["cond"]`` is falsy,
        ``args.num_samples`` unconditional samples are decoded and meshed to
        ``<recon_dir>/<i>_recon``.
      * Otherwise, for every point cloud in ``specs["TestSplit"]``, samples
        generated from that point cloud are meshed to
        ``<recon_dir>/<cls>/<mesh>/<i>_recon``. With ``args.filter``, up to six
        sampling rounds are run and only samples whose ``sdf_model`` output on
        the perturbed input points has mean absolute value <= 0.08 are kept
        (at most 10 per shape); shapes with no accepted sample are skipped.

    Relies on the module-level globals ``specs``, ``args`` and ``recon_dir``
    that the ``__main__`` block sets up.
    """
    if args.resume == 'finetune':
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            model = CombinedModel.load_from_checkpoint(specs["modulation_ckpt_path"], specs=specs, strict=False)
            ckpt = torch.load(specs["diffusion_ckpt_path"])
            # Remove the "diffusion_model." prefix (every occurrence) so the
            # keys match the diffusion submodule's own state dict.
            new_state_dict = {
                k.replace("diffusion_model.", ""): v for k, v in ckpt['state_dict'].items()
            }
            model.diffusion_model.load_state_dict(new_state_dict)

            model = model.cuda().eval()
    else:
        ckpt = "{}.ckpt".format(args.resume) if args.resume == 'last' else "epoch={}.ckpt".format(args.resume)
        resume = os.path.join(args.exp_dir, ckpt)
        model = CombinedModel.load_from_checkpoint(resume, specs=specs).cuda().eval()

    conditional = specs["diffusion_model_specs"]["cond"]

    if not conditional:
        samples = model.diffusion_model.generate_unconditional(args.num_samples)
        plane_features = model.vae_model.decode(samples)
        for i, plane_feature in enumerate(plane_features):
            mesh.create_mesh(model.sdf_model, plane_feature.unsqueeze(0), recon_dir + "/{}_recon".format(i), N=128,
                             max_batch=2 ** 21, from_plane_features=True)
        return

    # Read the split with a context manager so the file handle is closed.
    with open(specs["TestSplit"], "r") as split_file:
        test_split = [line.strip() for line in split_file]
    test_dataset = PCloader(specs["DataSource"], test_split, pc_size=specs.get("PCsize", 1024),
                            return_filename=True)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, num_workers=0)

    with tqdm(test_dataloader) as pbar:
        for idx, data in enumerate(pbar):
            pbar.set_description("Files generated: {}/{}".format(idx, len(test_dataloader)))

            point_cloud, atc, filename = data
            # batch_size=1, so the collated filename list has a single entry.
            filename = filename[0]

            # Class / mesh names come from fixed path positions; assumes the
            # loader's directory layout — TODO confirm.
            cls_name = filename.split("/")[-4]
            mesh_name = filename.split("/")[-1].split('.')[0]
            outdir = os.path.join(recon_dir, "{}/{}".format(cls_name, mesh_name))
            os.makedirs(outdir, exist_ok=True)

            if args.filter:
                # NOTE(review): this branch never uses ``atc``; confirm that
                # filtered sampling is meant to ignore it.
                threshold = 0.08
                tmp_lst = []
                count = 0
                while len(tmp_lst) < args.num_samples:
                    count += 1
                    samples, perturbed_pc = model.diffusion_model.generate_from_pc(point_cloud.cuda(),
                                                                                   batch=args.num_samples,
                                                                                   save_pc=outdir,
                                                                                   return_pc=True)
                    plane_features = model.vae_model.decode(samples)
                    perturbed_pc_pred = model.sdf_model.forward_with_plane_features(
                        plane_features, perturbed_pc.repeat(args.num_samples, 1, 1))
                    # Per-sample mean |prediction| on the perturbed input points.
                    consistency = F.l1_loss(perturbed_pc_pred, torch.zeros_like(perturbed_pc_pred),
                                            reduction='none')
                    loss = reduce(consistency, 'b ... -> b', 'mean', b=consistency.shape[0])

                    thresh_idx = loss <= threshold
                    tmp_lst.extend(plane_features[thresh_idx])

                    # Give up after six sampling rounds.
                    if count > 5:
                        break

                if not tmp_lst:
                    continue
                plane_features = tmp_lst[:10]

            else:
                samples, perturbed_pc = model.diffusion_model.generate_from_pc(point_cloud.cuda(),
                                                                               batch=args.num_samples,
                                                                               save_pc=outdir, return_pc=True,
                                                                               perturb_pc=False)
                atc_emb = model.vae_model.condition_encoder(atc.float().cuda().unsqueeze(dim=-1))
                atc_emb = atc_emb.repeat(args.num_samples, 1)
                # NOTE(review): ``samples_atc`` is computed but ``decode`` is
                # called with ``samples``; confirm whether the atc-conditioned
                # latent was meant to be decoded instead.
                samples_atc = torch.cat([samples, atc_emb], dim=1)
                plane_features = model.vae_model.decode(samples)

            for i, plane_feature in enumerate(plane_features):
                mesh.create_mesh(model.sdf_model, plane_feature.unsqueeze(0), outdir + "/{}_recon".format(i), N=128,
                                 max_batch=2 ** 10, from_plane_features=True)


if __name__ == "__main__":

    import argparse

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--exp_dir", "-e", required=True,
        help="This directory should include experiment specifications in 'specs_Art_test.json,' and logging will be done in this directory as well.",
    )
    arg_parser.add_argument(
        "--resume", "-r", default=None,
        help="continue from previous saved logs, integer value, 'last', or 'finetune'",
    )

    arg_parser.add_argument("--num_samples", "-n", default=5, type=int,
                            help='number of samples to generate and reconstruct')

    arg_parser.add_argument("--filter", default=False, help='whether to filter when sampling conditionally')
    arg_parser.add_argument("--class_name", "-c", default='laptop', type=str)

    args = arg_parser.parse_args()
    specs = json.load(open(os.path.join(args.exp_dir, "specs_Art_test.json")))
    print(specs["Description"])

    recon_dir = os.path.join(args.exp_dir, "recon")
    os.makedirs(recon_dir, exist_ok=True)

    if specs['training_task'] == 'modulation':
        latent_dir = os.path.join(args.exp_dir, "modulations")
        os.makedirs(latent_dir, exist_ok=True)
        test_modulations()
    elif specs['training_task'] == 'combined':
        test_generation()


